# ---------------------------------------------------------------------
# Copyright (c) 2025 Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
# ---------------------------------------------------------------------

from __future__ import annotations

import os
from pathlib import Path

import requests
from pydantic import Field, ValidationInfo, model_validator
from qai_hub.util.session import create_session

from qai_hub_models.configs._info_yaml_enums import (
    MODEL_DOMAIN,
    MODEL_LICENSE,
    MODEL_STATUS,
    MODEL_TAG,
    MODEL_USE_CASE,
)
from qai_hub_models.configs._info_yaml_llm_details import LLM_CALL_TO_ACTION, LLMDetails
from qai_hub_models.configs.code_gen_yaml import QAIHMModelCodeGen
from qai_hub_models.scorecard import ScorecardDevice
from qai_hub_models.utils.asset_loaders import ASSET_CONFIG, QAIHM_WEB_ASSET
from qai_hub_models.utils.base_config import BaseQAIHMConfig
from qai_hub_models.utils.path_helpers import (
    MODEL_IDS,
    MODELS_PACKAGE_NAME,
    QAIHM_MODELS_ROOT,
    QAIHM_PACKAGE_NAME,
    QAIHM_PACKAGE_ROOT,
    _get_qaihm_models_root,
)

# Public API of this module: the model info schema plus the info.yaml enums
# it re-exports from `qai_hub_models.configs._info_yaml_enums`.
# NOTE(review): MODEL_LICENSE is imported above but not listed here — confirm
# whether that omission is intentional.
__all__ = [
    "MODEL_DOMAIN",
    "MODEL_STATUS",
    "MODEL_TAG",
    "MODEL_USE_CASE",
    "QAIHMModelInfo",
]


class QAIHMModelInfo(BaseQAIHMConfig):
    """
    Schema & loader for a model's info.yaml.

    Each field below mirrors a key in info.yaml. Cross-field constraints
    (licensing, status, required assets, LLM details, ...) are enforced by
    `check_fields`, which runs after the individual fields are parsed.
    """

    # Name of the model as it will appear on the website.
    # Should have dashes instead of underscores and all
    # words capitalized. For example, `Whisper-Base-En`.
    name: str

    # Name of the model's folder within the repo.
    id: str

    # Whether or not the model is published on the website.
    # This should be set to public unless the model has poor accuracy/perf.
    status: MODEL_STATUS

    # A brief catchy headline explaining what the model does and why it may be interesting.
    # Must end with a period.
    headline: str

    # The domain the model is used in such as computer vision, audio, etc.
    domain: MODEL_DOMAIN

    # A 2-3 sentence description of how the model can be used.
    description: str

    # What task the model is used to solve, such as object detection, classification, etc.
    use_case: MODEL_USE_CASE

    # A list of applicable tags to add to the model
    tags: list[MODEL_TAG]

    # A list of real-world applications for which this model could be used.
    # This is free-form and almost anything reasonable here is fine.
    applicable_scenarios: list[str]

    # A list of other similar models in the repo.
    # Typically, any model that performs the same task is fine.
    # If nothing fits, this can be left blank. Limit to 3 models.
    related_models: list[str]

    # A list of device types for which this model could be useful.
    # If unsure what to put here, default to `Phone` and `Tablet`.
    form_factors: list[ScorecardDevice.FormFactor]

    # Whether the model has a static image uploaded in S3. All public models must have this.
    has_static_banner: bool

    # Whether the model has an animated asset uploaded in S3. This is optional.
    has_animated_banner: bool

    # CodeGen options from code-gen.yaml in the model's folder.
    code_gen_config: QAIHMModelCodeGen = Field(default_factory=QAIHMModelCodeGen)

    # A list of datasets for which the model has pre-trained checkpoints
    # available as options in `model.py`. Typically only has one entry.
    dataset: list[str]

    # A list of a few technical details about the model.
    #   Model checkpoint: The name of the downloaded model checkpoint file.
    #   Input resolution: The size of the model's input. For example, `2048x1024`.
    #   Number of parameters: The number of parameters in the model.
    #   Model size: The file size of the downloaded model asset.
    #       This and `Number of parameters` should be auto-generated by running `python qai_hub_models/scripts/autofill_info_yaml.py -m <model_name>`
    #   Number of output classes: The number of classes the model can classify or annotate.
    technical_details: dict[str, str | int | float]

    # The license type of the original model repo.
    license_type: MODEL_LICENSE

    # Device form factors for which we don't publish performance data.
    private_perf_form_factors: list[ScorecardDevice.FormFactor] | None = None

    # Some models are made by a company; if so, this is the ID of that model maker.
    model_maker_id: str | None = None

    # Link to the research paper where the model was first published. Usually an arxiv link.
    # Arxiv links must be `abs` links rather than direct links to the pdf.
    research_paper: str | None = None

    # The title of the research paper.
    research_paper_title: str | None = None

    # A link to the original github repo with the model's code.
    source_repo: str | None = None

    # A link to the model's license. Most commonly found in the github repo it was cloned from.
    license: str | None = None

    # Whether the model is compatible with the IMSDK Plugin for IOT devices
    imsdk_supported: bool = False

    # A link to the AIHub license, unless the license is more restrictive like GPL.
    # In that case, this should point to the same as the model license.
    deploy_license: str | None = None

    # Should be set to `ai-hub-models-license`, unless the license is more restrictive like GPL.
    # In that case, this should be the same as the model license.
    deploy_license_type: MODEL_LICENSE | None = None

    # If set, model assets shouldn't be distributed.
    restrict_model_sharing: bool = False

    # If status is private, this must have a reference to an internal issue with an explanation.
    # Must not be set for public models.
    status_reason: str | None = None

    # If the model outputs class indices, this field should be set and point
    # to a file in `qai_hub_models/labels`, which specifies the name for each index.
    labels_file: str | None = None

    # Whether the model is a large language model (LLM).
    model_type_llm: bool = False

    # Per-device download and app details, and whether the model is available
    # for purchase. Must be set if and only if `model_type_llm` is True.
    llm_details: LLMDetails | None = None

    @model_validator(mode="after")
    def check_fields(self, info: ValidationInfo) -> QAIHMModelInfo:
        """
        Validate cross-field constraints of the info spec for this model.

        Parameters
        ----------
        info
            Pydantic validation info. If its context maps `validate_urls_exist`
            to a truthy value, the labels file is checked on disk and the
            banner assets / LLM download URLs are checked with HTTP HEAD
            requests. Without that flag, no HTTP session is created.

        Returns
        -------
        QAIHMModelInfo
            This instance, if every check passes.

        Raises
        ------
        ValueError
            Describing the first check that failed.
        """
        validate_urls_exist: bool = info.context is not None and bool(
            info.context.get("validate_urls_exist", False)
        )

        # Validate ID
        if self.id not in MODEL_IDS:
            raise ValueError(f"{self.id} is not a valid QAI Hub Models ID.")
        if " " in self.id or "-" in self.id:
            raise ValueError("Model IDs cannot contain spaces or dashes.")
        if self.id.lower() != self.id:
            raise ValueError("Model IDs must be lowercase.")

        # Validate name (used as repo name for HF as well)
        if " " in self.name:
            raise ValueError("Model Name must not have a space.")
        if "_" in self.name:
            raise ValueError("Model Name should use dashes (-) instead of underscores.")

        # Headline should end with period
        if not self.headline.endswith("."):
            raise ValueError("Model headlines must end with a period.")

        # Validate related models are present
        for r_model in self.related_models:
            if r_model not in MODEL_IDS:
                raise ValueError(f"Related model {r_model} is not a valid model ID.")
            if r_model == self.id:
                raise ValueError(f"Model {r_model} cannot be related to itself.")
            # TODO: https://github.com/qcom-ai-hub/tetracode/issues/15078
            # Add validation to make sure related models are not private if this
            # model is public.

        # If paper is arxiv, it should be an abs link
        if (
            self.research_paper is not None
            and self.research_paper.startswith("https://arxiv.org/")
            and "/abs/" not in self.research_paper
        ):
            raise ValueError(
                "Arxiv links should be `abs` links, not link directly to pdfs."
            )

        # Whether this model has a page on the website
        model_is_available = self.status == MODEL_STATUS.PUBLIC
        # Whether this model can actually be downloaded by the public
        model_is_accessible = not self.restrict_model_sharing

        # License validation
        if not self.deploy_license and model_is_available and model_is_accessible:
            raise ValueError("deploy_license cannot be empty")
        if not self.deploy_license_type and model_is_available and model_is_accessible:
            raise ValueError("deploy_license_type cannot be empty")
        if self.license_type.url is not None and self.license != self.license_type.url:
            raise ValueError(
                f"License {self.license_type!s} must have URL {self.license_type.url}"
            )
        if self.license_type.deploy_license is None and model_is_available:
            raise ValueError(
                f"Models with license {self.license_type!s} cannot be published"
            )
        if self.deploy_license_type is not None:
            if self.license_type.deploy_license != self.deploy_license_type:
                raise ValueError(
                    f"License {self.license_type!s} must be paired with a deployment license of type {self.license_type.deploy_license}"
                )
            if (
                self.deploy_license_type.url is not None
                and self.deploy_license != self.deploy_license_type.url
            ):
                raise ValueError(
                    f"License {self.deploy_license_type!s} must have URL {self.deploy_license_type.url}"
                )
            if (
                self.deploy_license_type == self.license_type
                and self.deploy_license != self.license
            ):
                raise ValueError(
                    "If a model's source license and deployment license types are the same, their URLs must also be the same."
                )

        # Status Reason
        if self.status == MODEL_STATUS.PRIVATE and not self.status_reason:
            raise ValueError(
                "Private models must set `status_reason` in info.yaml with a link to the related issue."
            )

        if self.status == MODEL_STATUS.PUBLIC and self.status_reason:
            raise ValueError(
                "`status_reason` in info.yaml should not be set for public models."
            )

        # Labels file
        if (
            validate_urls_exist
            and self.labels_file is not None
            and not os.path.exists(ASSET_CONFIG.get_labels_file_path(self.labels_file))
        ):
            raise ValueError(f"Invalid labels file: {self.labels_file}")

        # Required assets exist
        if self.status == MODEL_STATUS.PUBLIC:
            if not os.path.exists(self.get_info_yaml_path()):
                raise ValueError("All public models must have an info.yaml")

            # If a model is not running in scorecard and is public,
            # there must be a perf yaml
            if (not self.code_gen_config.runs_in_scorecard) and not os.path.exists(
                self.get_perf_yaml_path()
            ):
                raise ValueError(
                    "All public models that don't run in scorecard must have a perf.yaml"
                )

            if not self.code_gen_config.supports_at_least_1_runtime:
                raise ValueError("Public models must support at least one export path")

            if not self.has_static_banner:
                raise ValueError("Public models must have a static asset.")

        # The HTTP session is only needed for the URL existence checks, so it
        # is only created when those checks were requested.
        session = create_session() if validate_urls_exist else None

        if session is not None and self.has_static_banner:
            static_banner_url = ASSET_CONFIG.get_web_asset_url(
                self.id, QAIHM_WEB_ASSET.STATIC_IMG
            )
            if session.head(static_banner_url).status_code != requests.codes.ok:
                raise ValueError(f"Static banner is missing at {static_banner_url}")
        if session is not None and self.has_animated_banner:
            animated_banner_url = ASSET_CONFIG.get_web_asset_url(
                self.id, QAIHM_WEB_ASSET.ANIMATED_MOV
            )
            if session.head(animated_banner_url).status_code != requests.codes.ok:
                raise ValueError(f"Animated banner is missing at {animated_banner_url}")

        # Sanity check that the asset config resolves this model to the
        # expected location within the repo.
        expected_qaihm_repo = Path("qai_hub_models") / "models" / self.id
        if expected_qaihm_repo != ASSET_CONFIG.get_qaihm_repo(self.id):
            raise ValueError("QAIHM repo not pointing to expected relative path")

        expected_example_use = f"qai_hub_models/models/{self.id}#example--usage"
        if expected_example_use != ASSET_CONFIG.get_example_use(self.id):
            raise ValueError(
                "Example-usage field not pointing to expected relative path"
            )

        # model_type_llm and llm_details must be set together.
        if self.model_type_llm:
            if not self.llm_details:
                raise ValueError("llm_details must be set if model type is LLM")

            # NOTE(review): this value is not read again within this method;
            # confirm whether an availability-dependent check is missing here.
            model_is_available = self.llm_details.call_to_action not in [
                LLM_CALL_TO_ACTION.CONTACT_FOR_PURCHASE,
                LLM_CALL_TO_ACTION.COMING_SOON,
                LLM_CALL_TO_ACTION.CONTACT_US,
            ]

            # Download URL can only be validated in a scope with a model ID, so this
            # is validated here rather than on the LLM details class' validator.
            if self.llm_details.devices and session is not None:
                for device_runtime_config_mapping in self.llm_details.devices.values():
                    for runtime_detail in device_runtime_config_mapping.values():
                        # The stored URL is `<version dir>/<relative path>`. The
                        # first character of the version dir is dropped
                        # (presumably a `v` prefix — TODO confirm).
                        version_dir, _, relative_path = (
                            runtime_detail.model_download_url.partition("/")
                        )
                        version = version_dir[1:]
                        model_download_url = ASSET_CONFIG.get_model_asset_url(
                            self.id, version, relative_path
                        )
                        if (
                            session.head(model_download_url).status_code
                            != requests.codes.ok
                        ):
                            raise ValueError(
                                f"Download URL is missing at {runtime_detail.model_download_url}"
                            )
        elif self.llm_details:
            raise ValueError("Model type must be LLM if llm_details is set")

        return self

    def get_package_name(self) -> str:
        """Return the dotted Python package name of this model."""
        return f"{QAIHM_PACKAGE_NAME}.{MODELS_PACKAGE_NAME}.{self.id}"

    def get_package_path(self, root: Path = QAIHM_PACKAGE_ROOT) -> Path:
        """Return the path of this model's folder under the models root derived from `root`."""
        return _get_qaihm_models_root(root) / self.id

    def get_model_definition_path(self) -> str:
        """Return the path to this model's `model.py`."""
        return os.path.join(
            ASSET_CONFIG.get_qaihm_repo(self.id, relative=False), "model.py"
        )

    def get_demo_path(self) -> str:
        """Return the path to this model's `demo.py`."""
        return os.path.join(
            ASSET_CONFIG.get_qaihm_repo(self.id, relative=False), "demo.py"
        )

    def get_labels_file_path(self):
        """Return the path of the labels file, or None if `labels_file` is not set."""
        if self.labels_file is None:
            return None
        return ASSET_CONFIG.get_labels_file_path(self.labels_file)

    def get_info_yaml_path(self, root: Path = QAIHM_PACKAGE_ROOT) -> Path:
        """Return the path to this model's `info.yaml`."""
        return self.get_package_path(root) / "info.yaml"

    def get_hf_pipeline_tag(self):
        """Return the Hugging Face pipeline tag that this model's use case maps to."""
        return self.use_case.map_to_hf_pipeline_tag()

    def get_hugging_face_metadata(
        self, root: Path = QAIHM_PACKAGE_ROOT
    ) -> dict[str, str | list[str]]:
        """
        Get the metadata for huggingface model cards.

        Parameters
        ----------
        root
            Not used by this method.

        Returns
        -------
        dict[str, str | list[str]]
            Values for the `library_name`, `license`, `tags` and
            `pipeline_tag` keys.
        """
        hf_metadata: dict[str, str | list[str]] = {}
        hf_metadata["library_name"] = "pytorch"
        # We only tag Hugging Face models with the specific license name if the source is copyleft.
        # Most models are tagged with the "other" license on HF because they use the AI Hub Models license.
        hf_metadata["license"] = (
            # 'Unlicensed' will appear only if this model is not public.
            # All models are validated to have a deployment license if they are public.
            self.deploy_license_type or MODEL_LICENSE.UNLICENSED
        ).huggingface_name
        hf_metadata["tags"] = [tag.name.lower() for tag in self.tags] + ["android"]
        hf_metadata["pipeline_tag"] = self.get_hf_pipeline_tag()
        return hf_metadata

    def get_model_details(self) -> str:
        """
        Return a markdown bullet list describing this model.

        The list contains the model type (the use case, lowercased and then
        capitalized) followed by one nested bullet per `technical_details`
        entry.
        """
        header = (
            "- **Model Type:** "
            + str(self.use_case).lower().capitalize()
            + "\n- **Model Stats:**"
        )
        stats = "".join(
            f"\n  - {name}: {val}" for name, val in self.technical_details.items()
        )
        return header + stats

    def get_perf_yaml_path(self, root: Path = QAIHM_PACKAGE_ROOT) -> Path:
        """Return the path to this model's `perf.yaml`."""
        return self.get_package_path(root) / "perf.yaml"

    def get_code_gen_yaml_path(self, root: Path = QAIHM_PACKAGE_ROOT) -> Path:
        """Return the path to this model's `code-gen.yaml`."""
        return self.get_package_path(root) / "code-gen.yaml"

    def get_readme_path(self, root: Path = QAIHM_PACKAGE_ROOT) -> Path:
        """Return the path to this model's `README.md`."""
        return self.get_package_path(root) / "README.md"

    def get_hf_model_card_path(self, root: Path = QAIHM_PACKAGE_ROOT) -> Path:
        """Return the path to this model's `HF_MODEL_CARD.md`."""
        return self.get_package_path(root) / "HF_MODEL_CARD.md"

    def get_requirements_path(self, root: Path = QAIHM_PACKAGE_ROOT) -> Path:
        """Return the path to this model's `requirements.txt`."""
        return self.get_package_path(root) / "requirements.txt"

    def has_model_requirements(self, root: Path = QAIHM_PACKAGE_ROOT) -> bool:
        """Return whether this model's `requirements.txt` exists on disk."""
        return os.path.exists(self.get_requirements_path(root))

    def get_web_url(self, website_url: str = ASSET_CONFIG.models_website_url) -> str:
        """Return the URL of this model's page under `website_url`."""
        return f"{website_url}/models/{self.id}"

    @property
    def is_gen_ai_model(self) -> bool:
        """Whether this model is tagged as an LLM or as generative AI."""
        return MODEL_TAG.LLM in self.tags or MODEL_TAG.GENERATIVE_AI in self.tags

    @classmethod
    def from_model(cls: type[QAIHMModelInfo], model_id: str) -> QAIHMModelInfo:
        """
        Load the info.yaml and code-gen config for the given model ID.

        Raises
        ------
        ValueError
            If the model has no info.yaml under the models root.
        """
        schema_path = QAIHM_MODELS_ROOT / model_id / "info.yaml"
        if not os.path.exists(schema_path):
            raise ValueError(f"{model_id} does not exist")
        info = cls.from_yaml(schema_path)
        # The code-gen options are loaded separately and attached after the
        # info.yaml has been parsed.
        info.code_gen_config = QAIHMModelCodeGen.from_model(model_id)
        return info

    def to_model_yaml(self, write_code_gen: bool = True) -> tuple[Path, Path | None]:
        """
        Write this config to the model's info.yaml.

        Parameters
        ----------
        write_code_gen
            If True, also write the code-gen config via its own
            `to_model_yaml`.

        Returns
        -------
        tuple[Path, Path | None]
            The info.yaml path, and the value returned by the code-gen
            config's `to_model_yaml` (None if `write_code_gen` is False).
        """
        info_path = QAIHM_MODELS_ROOT / self.id / "info.yaml"
        # code_gen_config is excluded from info.yaml; it is written separately below.
        self.to_yaml(
            path=info_path,
            exclude=["code_gen_config"],
        )
        code_gen_path = (
            self.code_gen_config.to_model_yaml(self.id) if write_code_gen else None
        )
        return info_path, code_gen_path
